#!/usr/bin/env python
# MIT License
#
# Deep Joint Demosaicking and Denoising
# Siggraph Asia 2016
# Michael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand
# 
# Copyright (c) 2016 Michael Gharbi
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Run the demosaicking network on an image or a directory containing multiple images."""

import argparse
import skimage.io
import numpy as np
import os
import re
import time
import tempfile
from tqdm import tqdm

# Set before caffe is imported. Presumably this raises glog's minimum log
# level so that caffe's informational/warning output is silenced -- TODO
# confirm against the glog documentation.
os.environ['GLOG_minloglevel'] = '2'
import caffe

# Lower and upper bound on the --noise argument; the command-line entry point
# raises a ValueError for values outside this range.
NOISE_LEVELS = [0.0000, 0.0784]  # Min/Max noise levels we trained on

def _psnr(a, b, crop=0, maxval=1.0):
    """Return the peak signal-to-noise ratio (in dB) between a and b.

    Args:
      a, b: arrays of identical shape. When crop > 0 they are indexed along
        three axes, so they must be 3-dimensional in that case.
      crop: number of border elements dropped from each side of the first
        two axes before comparing.
      maxval: peak signal value used to normalise the mean squared error.

    Returns:
      -10 * log10(MSE / maxval**2) computed over the (cropped) arrays.
    """
    if crop > 0:
        inner = (slice(crop, -crop), slice(crop, -crop), slice(None))
        a, b = a[inner], b[inner]

    mse = np.mean(np.square(a - b))
    return -10*np.log10(mse/(maxval*maxval))


def _uint2float(I):
    """Convert I to float32 and scale it by 1/256 (0.00390625)."""
    return I.astype(np.float32)*0.00390625


def _float2uint(I):
    """Quantise a float array to uint8 without modifying the input.

    Values are divided by 1/256 (0.00390625), offset by 0.5 so that the
    truncating uint8 cast rounds to nearest, and clipped to [0, 255].

    Args:
      I: numeric array (or array-like).

    Returns:
      A new uint8 array with the same shape as I. Unlike an in-place
      implementation, the caller's array is left untouched and integer
      inputs are accepted.
    """
    scaled = np.asarray(I)/0.00390625 + 0.5
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _blob_to_image(blob):
    """Copy a blob's data out as an image with the channel axis last.

    The leading axis of blob.data is dropped by reshaping to shape[1:], so
    it must have length 1; the remaining three axes are reordered from
    (0, 1, 2) to (1, 2, 0). The result is backed by a copy of the blob data.
    """
    without_batch = blob.data.shape[1:]
    image = np.copy(blob.data).reshape(without_batch)
    return image.transpose((1, 2, 0))

def _make_mosaic(im, mosaic_type):
    """Run an image through the Python mosaick layer of `demosaicnet.layers`.

    A one-layer caffe network is described in a temporary prototxt file,
    loaded, fed with `im` and run forward once.

    Args:
      im: array of shape (h, w, c).
      mosaic_type: 'bayer' (uses BayerMosaickLayer) or 'xtrans' (uses
        XTransMosaickLayer).

    Returns:
      A copy of the network's 'output' blob with singleton axes removed and
      the axes reordered so that the channel axis comes last.

    Raises:
      Exception: if mosaic_type is neither 'bayer' nor 'xtrans'.
    """
    if mosaic_type == 'bayer':
        layer = 'BayerMosaickLayer'
    elif mosaic_type == 'xtrans':
        layer = 'XTransMosaickLayer'
    else:
        raise Exception('Unknown mosaick type {}.'.format(mosaic_type))

    h, w, c = im.shape
    with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
        f.write("""name: 'pythonnet' force_backward: true
        input: 'data' input_shape { dim:1 dim: %d dim: %d dim: %d }
        layer { type: 'Python' name: 'output' bottom: 'data' top: 'output'
          python_param { module: 'demosaicnet.layers' layer: '%s' } }""" % (c, h, w, layer))
        fname = f.name

    # The file was created with delete=False so caffe can open it by name;
    # remove it ourselves even when loading the network fails.
    try:
        net = caffe.Net(fname, caffe.TEST)
    finally:
        os.remove(fname)

    net.blobs['data'].data[...] = im.transpose([2, 0, 1])

    # Read the output only after the forward pass, and copy it so the result
    # does not alias memory owned by the temporary network.
    net.forward()
    out = np.squeeze(net.blobs['output'].data).transpose([1, 2, 0])
    return out.copy()


def _clamp_tile(start, size, limit, align=2):
    """Return (start, end) of a `size`-long tile kept inside [0, limit).

    A tile that would run past `limit` is pulled back so that it ends on the
    largest multiple of `align` not exceeding `limit`.
    """
    end = start + size
    if end > limit:
        end = align*(limit//align)
        start = end - size
    return start, end


def demosaick(net, M, noise, psize, crop):
    """Run the network over M tile by tile and assemble the full result.

    Args:
      net: caffe network with a 'mosaick' input blob, an 'output' blob and,
        optionally, a 'noise_level' input blob.
      M: mosaicked image indexed as (row, column, channel).
      noise: value written to the 'noise_level' blob when the network has one.
      psize: requested tile size; it is clamped to the image height/width and
        rounded down to an even number.
      crop: number of pixels by which each tile's output is offset inside the
        tile (the caller derives it from the 'mosaick'/'output' blob sizes).

    Returns:
      (R, runtime): R is a float32 array with M's shape, clipped to [0, 1];
      pixels not covered by any tile output stay 0. runtime is the elapsed
      wall-clock time in milliseconds.

    Raises:
      ValueError: if the (clamped) tile size does not exceed 2*crop, since
        consecutive tiles could then not advance.
    """
    start_time = time.time()
    h, w = M.shape[:2]

    psize = min(psize, h, w)
    psize -= psize % 2
    patch_step = psize - 2*crop
    if patch_step <= 0:
        raise ValueError(
            'Tile size {} is too small for a network crop of {}.'.format(
                psize, crop))
    shift_factor = 2  # tile ends are aligned to multiples of this

    # Result array
    R = np.zeros(M.shape, dtype=np.float32)

    # Loop-invariant: does the network take a noise level input?
    has_noise_input = 'noise_level' in net.blobs

    rangex = range(0, w - 2*crop, patch_step)
    rangey = range(0, h - 2*crop, patch_step)
    ntiles = len(rangex)*len(rangey)
    with tqdm(total=ntiles, unit='tiles', unit_scale=True) as pbar:
        for x in rangex:
            start_x, end_x = _clamp_tile(x, psize, w, shift_factor)
            for y in rangey:
                start_y, end_y = _clamp_tile(y, psize, h, shift_factor)

                # (rows, cols, channels) -> (1, channels, rows, cols)
                tileM = M[start_y:end_y, start_x:end_x, :]
                tileM = tileM[np.newaxis, :, :, :].transpose((0, 3, 1, 2))

                net.blobs['mosaick'].reshape(*tileM.shape)
                net.blobs['mosaick'].data[...] = tileM

                if has_noise_input:
                    net.blobs['noise_level'].reshape(1)
                    net.blobs['noise_level'].data[...] = noise

                net.forward()

                out = _blob_to_image(net.blobs['output'])
                out_h, out_w = out.shape[:2]

                # The output is placed `crop` pixels inside the tile origin.
                R[start_y + crop:start_y + crop + out_h,
                  start_x + crop:start_x + crop + out_w, :] = out

                pbar.update(1)

    np.clip(R, 0.0, 1.0, out=R)

    runtime = (time.time() - start_time)*1000  # in ms

    return R, runtime

def main(args):
    """Demosaick every input image and save the result under args.output.

    Reads the following attributes of `args`: model, gpu, input, offset_x,
    offset_y, noise, mosaic_type, tile_size and output.

    For each input:
      * 2-D images are treated as raw mosaicks without ground truth: they are
        optionally offset by reflect-padding on the left/top, replicated to
        three channels, and the network output is saved with the input's bit
        depth.
      * Other images are treated as ground truth: optional Gaussian noise is
        added, a mosaick is generated, PSNR against the input is reported,
        and a side-by-side strip (reference, noisy input, mosaick, result,
        normalised absolute difference) is saved.

    Output files are named 'dj_output.jpg' for a single input file, and
    '<input file name>_dj_output.jpg' when args.input is a directory so that
    results do not overwrite each other.

    Raises:
      ValueError: if an input image is neither uint8 nor uint16.
    """
    arch_path = os.path.join(args.model, 'deploy.prototxt')
    weights_path = os.path.join(args.model, 'weights.caffemodel')
    if args.gpu:
        print('  - using GPU')
        caffe.set_mode_gpu()
    else:
        print('  - using CPU')
        caffe.set_mode_cpu()
    net = caffe.Net(arch_path, weights_path, caffe.TEST)

    # Number of pixels the network trims on each side of its input.
    crop = (net.blobs['mosaick'].data.shape[-1]
            - net.blobs['output'].data.shape[-1]) // 2

    print('Crop {}'.format(crop))

    regexp = re.compile(r".*\.(png|tif)")
    input_is_dir = os.path.isdir(args.input)
    if input_is_dir:
        print('dir')
        inputs = [os.path.join(args.input, f)
                  for f in sorted(os.listdir(args.input)) if regexp.match(f)]
    else:
        inputs = [args.input]

    avg_psnr = 0
    n = 0
    for fname in inputs:
        print('+ Processing {}'.format(fname))
        Iref = skimage.io.imread(fname)
        dtype = Iref.dtype
        if dtype not in [np.uint8, np.uint16]:
            raise ValueError('Input type not handled: {}'.format(dtype))
        Iref = skimage.img_as_float(Iref)

        # No need for offsets if we have the ground-truth
        has_groundtruth = len(Iref.shape) != 2
        if not has_groundtruth:
            # Offset the image to match our mosaic pattern
            if args.offset_x > 0:
                print('  - offset x')
                Iref = np.pad(Iref, [(0, 0), (args.offset_x, 0)], 'reflect')

            if args.offset_y > 0:
                print('  - offset y')
                Iref = np.pad(Iref, [(args.offset_y, 0), (0, 0)], 'reflect')
            Iref = np.dstack((Iref, Iref, Iref))

        if has_groundtruth and args.noise > 0:
            print('  - adding noise sigma={:.3f}'.format(args.noise))
            I = Iref + np.random.normal(
                loc=0.0, scale=args.noise, size=Iref.shape)
        else:
            I = Iref

        c = 0
        if crop > 0:
            # Round the padding up to a whole number of pattern periods
            # (2 for Bayer, 6 for X-Trans) so we don't shift the pattern.
            period = 2 if args.mosaic_type == 'bayer' else 6
            c = crop + (-crop) % period
            I = np.pad(I, [(c, c), (c, c), (0, 0)], 'reflect')

        if has_groundtruth:
            print('  - making mosaick')
        else:
            print('  - formatting mosaick')
        M = _make_mosaic(I, args.mosaic_type)

        R, runtime = demosaick(net, M, args.noise, args.tile_size, crop)

        if crop > 0:
            # Undo the padding added above.
            R = R[c:-c, c:-c, :]
            I = I[c:-c, c:-c, :]
            M = M[c:-c, c:-c, :]

        if has_groundtruth:
            p = _psnr(R, Iref, crop=crop)
            avg_psnr += p
            n += 1
            diff = np.abs(R - Iref)
            peak = np.amax(diff)
            if peak > 0:  # avoid 0/0 when the result matches exactly
                diff /= peak
            out = np.hstack((Iref, I, M, R, diff))
            out = _float2uint(out)
            print('  PSNR = {:.1f} dB, time = {} ms'.format(p, int(runtime)))
        else:
            # Only R is saved in this branch, so only R needs the offsets
            # stripped again.
            if args.offset_x > 0:
                print('  - remove offset x')
                R = R[:, args.offset_x:]

            if args.offset_y > 0:
                print('  - remove offset y')
                R = R[args.offset_y:, :]

            print('  - raw image without groundtruth, bypassing metric')
            if dtype == np.uint16:
                out = skimage.img_as_uint(R)
            else:
                out = skimage.img_as_ubyte(R)

        if input_is_dir:
            # One output per input, otherwise each image would overwrite the
            # previous result.
            outputname = os.path.join(
                args.output,
                '{}_dj_output.jpg'.format(os.path.basename(fname)))
        else:
            outputname = os.path.join(args.output, 'dj_output.jpg')
        # NOTE(review): 16-bit results are still written to a .jpg file;
        # confirm the image writer in use handles that as intended.
        skimage.io.imsave(outputname, out)

    if n > 0:
        avg_psnr /= n
        print('+ Average PSNR = {:.1f} dB'.format(avg_psnr))


if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the requested
    # noise level against NOISE_LEVELS and hand over to main().
    # NOTE(review): the --input/--output/--model defaults are absolute paths
    # under one user's home directory; pass them explicitly on other machines.
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '--input', type=str,
        default='/Users/tomheaven/Documents/Current/demosaicnet-master/data/test_images/000003.png',
        help='path to input image or folder.')
    cli.add_argument(
        '--output', type=str,
        default='/Users/tomheaven/Documents/Current/demosaicnet-master/output',
        help='path to output folder.')
    cli.add_argument(
        '--model', type=str,
        default='/Users/tomheaven/Documents/Current/demosaicnet-master/pretrained_models/bayer',
        help='path to trained model (folder containing deploy.prototxt and weights.caffemodel).')
    cli.add_argument(
        '--noise', type=float, default=0.0,
        help='standard deviation of additive Gaussian noise, w.r.t to a [0,1] intensity scale.')
    cli.add_argument(
        '--offset_x', type=int, default=0,
        help='number of pixels to offset the mosaick in the x-axis.')
    cli.add_argument(
        '--offset_y', type=int, default=0,
        help='number of pixels to offset the mosaick in the y-axis.')
    cli.add_argument(
        '--tile_size', type=int, default=512,
        help='split the input into tiles of this size.')
    cli.add_argument(
        '--gpu', dest='gpu', action='store_true',
        help='use the GPU for processing.')
    cli.add_argument(
        '--mosaic_type', type=str, default='bayer',
        choices=['bayer', 'xtrans'],
        help='type of mosaick (xtrans or bayer)')

    cli.set_defaults(gpu=False)

    args = cli.parse_args()

    # Refuse noise levels outside the range given by NOISE_LEVELS.
    if args.noise < NOISE_LEVELS[0] or args.noise > NOISE_LEVELS[1]:
        raise ValueError(
            'The model was trained on noise levels in [{}, {}]'.format(
                NOISE_LEVELS[0], NOISE_LEVELS[1]))

    main(args)
